This is more or less a port of Jax MD into Dex, to see what molecular dynamics looks like in
Dex. For now, the two implementations are structured quite similarly, though the details
differ.
Math
def truncate (x:Float) : Float = case x < 0.0 of
  True -> -floor(-x)
  False -> floor x
-- A mod that matches np.mod and python mod.
def pmod (x:Float) (y:Float) : Float =
  x - floor(x / y) * y
-- A mod that matches np.fmod and c fmod.
def fmod (x:Float) (y:Float) : Float =
  x - truncate (x / y) * y
def Vec (dim:Type) [Ix dim] : Type = dim => Float
def sq_norm {dim} (r:Vec dim) : Float =
  sum $ for i. r.i * r.i
def norm {dim} (r:Vec dim) : Float =
  sqrt $ sq_norm r
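The two mods above differ only in how they round the quotient: `pmod` floors it, so the result takes the sign of `y`, while `fmod` truncates toward zero, so the result takes the sign of `x`. A plain-Python sketch (not Dex) makes the difference visible; it uses only the standard `math` module, and the function names simply mirror the Dex definitions:

```python
import math

def truncate(x: float) -> float:
    # Round toward zero, as the Dex `truncate` does.
    return float(-math.floor(-x)) if x < 0.0 else float(math.floor(x))

def pmod(x: float, y: float) -> float:
    # Floor-based remainder: result has the sign of y (like Python's % and np.mod).
    return x - math.floor(x / y) * y

def fmod(x: float, y: float) -> float:
    # Truncation-based remainder: result has the sign of x (like C fmod and math.fmod).
    return x - truncate(x / y) * y

print(pmod(-7.0, 3.0), -7.0 % 3.0)            # 2.0 2.0
print(fmod(-7.0, 3.0), math.fmod(-7.0, 3.0))  # -1.0 -1.0
```

For positive arguments both give the same answer; `pmod` is the one you want for wrapping coordinates into a periodic box, since it always lands in `[0, y)` for positive `y`.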
Useful Quantities
This computes a size for a box of the given number of dimensions,
such that the given number of particles will fill it with the given
density.
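The computation described here reduces to one formula: a box of side L in d dimensions has volume L^d, so N particles fill it at number density ρ when N / L^d = ρ, that is, L = (N / ρ)^(1/d). A minimal plain-Python sketch, assuming a cubic (square, in 2D) box; the name follows Jax MD's `quantity.box_size_at_number_density`, but the argument order here is our own choice:

```python
def box_size_at_number_density(particle_count: int, number_density: float,
                               dim: int) -> float:
    # A box of side L in `dim` dimensions has volume L**dim, so we solve
    # particle_count / L**dim == number_density for L.
    return (particle_count / number_density) ** (1.0 / dim)

L = box_size_at_number_density(500, 1.2, 2)
print(L, 500 / L ** 2)  # side length, and the density it reproduces (about 1.2)
```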
def harmonic_soft_sphere (sigma:Float) (r:Float) : Float =
  soft_sphere 1.0 2.0 sigma r
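`soft_sphere` itself does not appear in this section. Assuming it follows the usual soft-sphere form (the one Jax MD's `energy.soft_sphere` implements), it is a finite-range repulsion ε/α · (1 − r/σ)^α for r < σ and exactly zero beyond, so `harmonic_soft_sphere` is the α = 2 case. A plain-Python sketch, with the argument order `epsilon, alpha, sigma, r` inferred from the call above:

```python
def soft_sphere(epsilon: float, alpha: float, sigma: float, r: float) -> float:
    # Repulsive only inside the cutoff r < sigma; exactly zero outside it.
    if r >= sigma:
        return 0.0
    return epsilon / alpha * (1.0 - r / sigma) ** alpha

def harmonic_soft_sphere(sigma: float, r: float) -> float:
    # epsilon = 1, alpha = 2: a quadratic penalty on overlap.
    return soft_sphere(1.0, 2.0, sigma, r)

print(harmonic_soft_sphere(1.0, 0.5))  # 0.125, i.e. 0.5 * (1 - 0.5)**2
print(harmonic_soft_sphere(1.0, 1.5))  # 0.0, beyond the cutoff
```

The hard cutoff at σ is what later makes neighbor lists exact rather than approximate: pairs farther apart than σ contribute nothing.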
Here we have a naive pairwise energy function constructor that
promotes a two-particle energy to a whole-system energy by simply
applying it to every pair of distinct particles. This takes O(n^2)
work for n particles, but it gives us something simple to test with.
def pair_energy {dim n}
    (energy:Float->Float)
    (displacement:Displacement dim)
    (r: n=>Vec dim)
    : Float =
  sum $ for i. sum for j:(..<i).
    energy $ norm $ displacement r.i r.(inject _ j)
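The same naive construction in plain Python, as a sketch for comparison (free, non-periodic space; `free_displacement` and the example potential `harmonic` are illustrative stand-ins, not names from the Dex code). Restricting `j` to indices below `i` visits each unordered pair exactly once, just like `for j:(..<i)`:

```python
import math
from typing import Callable, List, Sequence

Vec = Sequence[float]

def free_displacement(a: Vec, b: Vec) -> List[float]:
    # Displacement in open (non-periodic) space.
    return [x - y for x, y in zip(a, b)]

def norm(v: Vec) -> float:
    return math.sqrt(sum(x * x for x in v))

def pair_energy(energy: Callable[[float], float],
                displacement: Callable[[Vec, Vec], List[float]],
                r: Sequence[Vec]) -> float:
    # Every unordered pair once: j runs over indices below i, as in `for j:(..<i)`.
    return sum(energy(norm(displacement(r[i], r[j])))
               for i in range(len(r)) for j in range(i))

def harmonic(d: float) -> float:
    # Example pair potential: harmonic overlap penalty, zero for d >= 1.
    return 0.5 * (1.0 - d) ** 2 if d < 1.0 else 0.0

positions = [[0.0, 0.0], [0.5, 0.0], [3.0, 3.0]]
print(pair_energy(harmonic, free_displacement, positions))  # 0.125: only the close pair counts
```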
def get_position {dim n} (state:FireDescentState n dim) : (n=>Vec dim) =
  (MkFireDescentState {R, V, F, dt, alpha, n_pos}) = state
  R
def fire_descent_init {dim n}
    (dt:Float)
    (alpha:Float)
    (energy:Energy n dim)
    (r: n=>Vec dim)
    : FireDescentState n dim =
  force = \rp. -(grad energy) rp
  V = for i:n j:dim. 0.0
  F = force r
  n_pos:Int = 0
  MkFireDescentState {R=r, V, F, dt, alpha, n_pos}
def fire_descent_step {dim n}
    (shift:Shift dim)
    (energy:Energy n dim)
    (state:FireDescentState n dim)
    : FireDescentState n dim =
  -- Constants that parameterize the FIRE algorithm.
  -- TODO: Thread these constants through somehow.
  -- dougalm@ is there a canonical way to do this?
  dt_start = 0.1
  dt_max = 0.4
  n_min = 5
  f_inc = 1.1
  f_dec = 0.5
  f_alpha = 0.99
  alpha_start = 0.1
  ε = 0.000000001
  force = \r. -(grad energy) r
  -- FIRE algorithm.
  (MkFireDescentState {R, V, F, dt, alpha, n_pos}) = state
  -- Do a Velocity-Verlet step.
  R = for i. shift R.i (V.i *. dt + F.i *. pow dt 2)
  F_old = F
  F = force R
  V = V + dt * 0.5 .* (F_old + F)
  -- Mix the velocity toward the direction of the force, keeping its magnitude.
  F_norm = sqrt $ sum for (i, j). pow F.i.j 2
  V_norm = sqrt $ sum for (i, j). pow V.i.j 2
  V = V + alpha .* (F *. V_norm / (F_norm + ε) - V)
  -- Check whether the force is aligned with the velocity.
  FdotV = sum for (i, j). F.i.j * V.i.j
  -- Decide whether to increase the speed of the simulation or reduce it.
  (n_pos, dt, alpha) = if FdotV >= 0.0
    then
      dt' = if n_pos >= n_min then (min (dt * f_inc) dt_max) else dt
      alpha' = if n_pos >= n_min then (alpha * f_alpha) else alpha
      (n_pos + 1, dt', alpha')
    else (0, dt * f_dec, alpha_start)
  V = if FdotV >= 0.0 then V else zero
  MkFireDescentState {R, V, F, dt, alpha, n_pos}
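To see the FIRE (Fast Inertial Relaxation Engine) control logic in isolation, here is a plain-Python sketch that mirrors `fire_descent_init` and `fire_descent_step` line by line, under simplifying assumptions: free space (no `shift` function), positions stored as lists of lists, and a caller-supplied `force` function in place of `grad`. The names `FireState`, `fire_init` and `fire_step` are hypothetical:

```python
import math
from dataclasses import dataclass
from typing import Callable, List

Array = List[List[float]]  # n particles x dim coordinates

# The same constants as the Dex fire_descent_step.
DT_MAX, N_MIN, F_INC, F_DEC, F_ALPHA, ALPHA_START, EPS = 0.4, 5, 1.1, 0.5, 0.99, 0.1, 1e-9

@dataclass
class FireState:
    R: Array
    V: Array
    F: Array
    dt: float
    alpha: float
    n_pos: int

def fire_init(force: Callable[[Array], Array], r: Array,
              dt: float = 0.1, alpha: float = 0.1) -> FireState:
    return FireState(r, [[0.0] * len(p) for p in r], force(r), dt, alpha, 0)

def fire_step(force: Callable[[Array], Array], s: FireState) -> FireState:
    dt, alpha, n_pos = s.dt, s.alpha, s.n_pos
    # Position update in free space; like the Dex code, the force term uses dt**2.
    R = [[x + v * dt + f * dt ** 2 for x, v, f in zip(p, vp, fp)]
         for p, vp, fp in zip(s.R, s.V, s.F)]
    F = force(R)
    # Velocity update with the average of the old and new forces.
    V = [[v + dt * 0.5 * (fo + fn) for v, fo, fn in zip(vp, fop, fnp)]
         for vp, fop, fnp in zip(s.V, s.F, F)]
    # Mix V toward the force direction, scaled to |V|.
    f_norm = math.sqrt(sum(f * f for fp in F for f in fp))
    v_norm = math.sqrt(sum(v * v for vp in V for v in vp))
    V = [[v + alpha * (f * v_norm / (f_norm + EPS) - v) for v, f in zip(vp, fp)]
         for vp, fp in zip(V, F)]
    # Power F.V >= 0: keep going (speeding up after N_MIN steps); otherwise stop and restart.
    power = sum(f * v for fp, vp in zip(F, V) for f, v in zip(fp, vp))
    if power >= 0.0:
        if n_pos >= N_MIN:
            dt, alpha = min(dt * F_INC, DT_MAX), alpha * F_ALPHA
        n_pos += 1
    else:
        n_pos, dt, alpha = 0, dt * F_DEC, ALPHA_START
        V = [[0.0] * len(vp) for vp in V]
    return FireState(R, V, F, dt, alpha, n_pos)

# Usage on a quadratic bowl E(r) = 0.5 * |r|^2, whose force is -r.
def energy(r: Array) -> float:
    return 0.5 * sum(x * x for p in r for x in p)

def force(r: Array) -> Array:
    return [[-x for x in p] for p in r]

state = fire_init(force, [[1.0, -2.0], [0.5, 0.3]])
for _ in range(100):
    state = fire_step(force, state)
print(energy(state.R))  # far below the initial energy of 2.67
```

The design point FIRE relies on is the power check: as long as the force does positive work on the velocity, the integrator is allowed to speed up; the first time it overshoots, the velocity is zeroed and the time step halved.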
Drawing
import png
import diagram
Now a tool to draw a two-dimensional system, where each particle is a
disk of given size.
TwoDimensions = Fin 2
def draw_system {n} radius (r: n=>Vec TwoDimensions) : Diagram =
  disks = concat_diagrams for i.
    circle radius |> move_xy (r.i.(0 @ TwoDimensions), r.i.(1 @ TwoDimensions))
  disks
Here's the initial energy we compute for our system.
:t energy R_init_small
Float32
energy R_init_small
74.69006
Initialize a simulation:
state_small = fire_descent_init 0.1 0.1 energy R_init_small
Then test one step of minimization. As expected, the energy decreases
from its initial value:
energy $ get_position $ fire_descent_step free_shift energy state_small
71.78407
Now we can test that our code basically works by running 100 steps of
minimization.
%time
(state_small', energies) = scan state_small \i:(Fin 100) s.
  s' = fire_descent_step (periodic_shift L_small) energy s
  (s', energy $ get_position s')
Compile time: 637.460 ms
Run time: 779.614 ms
Here's how the energy decreases over time.
%time
:html show_plot $ y_plot energies
Compile time: 392.019 ms
Run time: 3.376 ms
Here's what the system looks like after minimization.
The pair_energy function above computes the influence of every atom on
every other atom, no matter how far apart they are.
To simulate more efficiently, we'd like to approximate the pairwise
energy by including only contributions from atoms that are close
enough to each other that their interaction cannot be neglected.
This is a two-step operation:
1. Break the simulation volume into a grid of cells, and make one linear
   pass over the atoms to group them by the cell each one falls in.
2. Traverse every pair of adjacent cells and compute energy terms only
   for pairs of atoms drawn from those cells.
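The first step, binning, can be sketched in plain Python. This is an illustration under assumptions: a periodic square or cubic box of side `box_size`, cells at least `cutoff` wide, and a dict of Python lists standing in for the fixed-size `CellTable` of `BoundedList`s used below; `bin_atoms` is a hypothetical name:

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

def bin_atoms(r: Sequence[Sequence[float]], box_size: float, cutoff: float
              ) -> Tuple[int, Dict[Tuple[int, ...], List[int]]]:
    # As many cells per side as fit while keeping each cell at least `cutoff`
    # wide, so any two interacting atoms sit in the same or adjacent cells.
    grid_size = max(1, int(math.floor(box_size / cutoff)))
    cell_width = box_size / grid_size
    cells: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
    for atom, pos in enumerate(r):  # one linear pass over the atoms
        # Wrap into the box; the trailing % guards against rounding up to grid_size.
        cell = tuple(int((x % box_size) // cell_width) % grid_size for x in pos)
        cells[cell].append(atom)
    return grid_size, dict(cells)

grid_size, cells = bin_atoms([[0.2, 0.2], [0.9, 0.1], [3.5, 3.9]], box_size=4.0, cutoff=1.0)
print(grid_size, cells)  # 4 {(0, 0): [0, 1], (3, 3): [2]}
```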
Bounded lists
We start with an abstraction of an incrementally growable list. To
get O(1) insertion at the end, we (currently) have to give an upper
bound for the list's size.
-- TODO Can we encapsulate this BoundedList type as a `data` and still
-- define in-place mutation operations on it?
def BoundedList n [Ix n] a [Data a] = (n & (n => a))
def unsafe_next_index {n} [Ix n] (ix:n) : n =
  unsafe_from_ordinal n $ ordinal ix + 1
def empty_bounded_list {n a} [Ix n, Data a] (dummy_val: a) : BoundedList n a =
  (unsafe_from_ordinal _ 0, for _. dummy_val)
-- The point of a `BoundedList` is O(1) push by in-place mutation
def bd_push {h n a} [Ix n] (lst_ref:Ref h (BoundedList n a)) (x: a) : {State h} Unit =
  sz_ref = fst_ref lst_ref
  sz = get sz_ref
  buf_ref = snd_ref lst_ref
  if ordinal sz < size n
    then
      buf_ref!sz := x
      sz_ref := unsafe_next_index sz
    else
      todo -- throw ()
-- Once we're done pushing, we can compact a `BoundedList` into a standard `List`.
def as_list {n a} (lst:BoundedList n a) : List a =
  (lim, buf) = lst
  n_result = ordinal lim
  AsList _ $ for i:(Fin n_result). buf.(unsafe_from_ordinal _ $ ordinal i)
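A rough Python analogue of this structure (a sketch, not a translation of Dex's effect system) is a preallocated buffer plus a fill count. Python lists already append in amortized O(1), so the explicit capacity here only mirrors the Dex design; raising `OverflowError` when full is our own choice where `bd_push` has a `todo`:

```python
from typing import Generic, List, TypeVar

T = TypeVar("T")

class BoundedList(Generic[T]):
    """Fixed-capacity buffer with O(1) push; the capacity is chosen up front."""

    def __init__(self, capacity: int, dummy_val: T) -> None:
        self.size = 0
        self.buf: List[T] = [dummy_val] * capacity  # preallocated storage

    def push(self, x: T) -> None:
        if self.size >= len(self.buf):
            # The Dex `bd_push` leaves this case as `todo`; here we raise instead.
            raise OverflowError("BoundedList is full")
        self.buf[self.size] = x
        self.size += 1

    def as_list(self) -> List[T]:
        # Compact: keep only the filled prefix, dropping the dummy values.
        return self.buf[: self.size]

lst = BoundedList(3, -1)
lst.push(7)
lst.push(8)
print(lst.as_list())  # [7, 8]
```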
Cell list
We define a single index for the whole grid.
def GridIx dim grid_size [Ix dim] = dim => (Fin grid_size)
A cell list is then just a table that maps each cell in the grid to a
BoundedList of the indices of the atoms that fall in that cell.
def CellTable dim grid_size bucket_size atom_ix [Ix dim] [Data atom_ix] = GridIx dim grid_size => BoundedList (Fin bucket_size) atom_ix
The neighbor list computation. The point of the exercise is that
this is not O(#atoms^2), but rather O(#cells * 9 * (#atoms per
cell)^2), because it only considers atoms from adjacent cells as
potential neighbors (in two dimensions, a cell and its 8 neighbors
make 9 cells).
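To make the cost argument concrete, here is a plain-Python sketch of a cell-list neighbor search checked against the O(n^2) brute force. It is not the Dex implementation: it assumes a periodic square box with at least 3 cells per side (so the 3^d candidate cells are distinct), uses dicts and sets instead of `CellTable`, and all function names are hypothetical:

```python
import itertools
import math
import random
from collections import defaultdict

def periodic_distance(a, b, box_size):
    # Minimum-image distance in a periodic box.
    d2 = 0.0
    for x, y in zip(a, b):
        dx = (x - y) % box_size
        dx = min(dx, box_size - dx)
        d2 += dx * dx
    return math.sqrt(d2)

def cell_list_pairs(r, box_size, cutoff):
    dim = len(r[0])
    grid_size = int(box_size // cutoff)
    assert grid_size >= 3, "need >= 3 cells per side so neighbor cells are distinct"
    width = box_size / grid_size
    cells = defaultdict(list)
    for atom, pos in enumerate(r):
        cells[tuple(int((x % box_size) // width) % grid_size for x in pos)].append(atom)
    pairs = set()
    for cell, atoms in cells.items():
        # 3**dim candidate cells: this one and its neighbors (9 in 2D).
        for offset in itertools.product((-1, 0, 1), repeat=dim):
            other = tuple((c + o) % grid_size for c, o in zip(cell, offset))
            for i in atoms:
                for j in cells.get(other, ()):
                    if i < j and periodic_distance(r[i], r[j], box_size) < cutoff:
                        pairs.add((i, j))
    return pairs

def brute_force_pairs(r, box_size, cutoff):
    n = len(r)
    return {(i, j) for i in range(n) for j in range(i + 1, n)
            if periodic_distance(r[i], r[j], box_size) < cutoff}

random.seed(0)
box = 10.0
atoms = [[random.uniform(0.0, box), random.uniform(0.0, box)] for _ in range(200)]
print(cell_list_pairs(atoms, box, 1.0) == brute_force_pairs(atoms, box, 1.0))  # True
```

Each atom is compared only with the atoms in 9 cells rather than with all 200 atoms, which is where the savings come from once the box holds many cells.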
In that configuration, we find this many pairs of neighbors:
(AsList k _) = as_list res
k
3090
Now that we have the concept of neighbor lists, we can define a
variant of pair_energy that only considers pairs of atoms that the
neighbor list says are close.
def pair_energy_nl {dim n}
    (energy:Float->Float)
    (displacement:Displacement dim)
    (r: n=>Vec dim)
    (neighbors:List (n & n))
    : Float =
  (AsList k nbrs) = neighbors
  sum for i.
    a1ix, a2ix = nbrs.i
    -- Count each unordered pair once (this also skips any self-pairs).
    case (ordinal a1ix) < (ordinal a2ix) of
      True -> energy $ norm $ displacement r.a1ix r.a2ix
      False -> 0.0
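The `a1ix < a2ix` filter is what keeps the neighbor-list energy equal to the naive one. A plain-Python sketch (free space, a hypothetical neighbor list built by brute force with a 0.5 halo that contains both orderings of every pair and the self-pairs, and an example potential that is zero beyond distance 1):

```python
import math

def dist(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def harmonic(d):
    # Example pair potential that is exactly zero for d >= 1.
    return 0.5 * (1.0 - d) ** 2 if d < 1.0 else 0.0

def pair_energy(energy, r):
    # Naive reference: all unordered pairs.
    return sum(energy(dist(r[i], r[j])) for i in range(len(r)) for j in range(i))

def pair_energy_nl(energy, r, neighbors):
    # The list may hold (i, j), (j, i) and (i, i); keep only a < b so each pair counts once.
    return sum(energy(dist(r[a], r[b])) for a, b in neighbors if a < b)

r = [[0.0, 0.0], [0.6, 0.0], [0.0, 0.8], [5.0, 5.0]]
# Every ordered pair (self-pairs included) closer than 1 + halo = 1.5.
neighbors = [(i, j) for i in range(len(r)) for j in range(len(r)) if dist(r[i], r[j]) < 1.5]
print(pair_energy(harmonic, r), pair_energy_nl(harmonic, r, neighbors))
```

Without the filter, each real pair would be counted twice and every self-pair would add `harmonic(0) = 0.5`; pairs inside the halo but beyond range 1, such as atoms 1 and 2 here, cost a distance evaluation but contribute nothing.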
-- A helper for short-circuiting `any` computation
def fast_any {n eff} [Ix n] (f: n -> {|eff} Bool) : {|eff} Bool =
  iter \ct. if ct >= size n
    then Done False
    else if f (ct @ n) then Done True else Continue
And now that this basically works, we can package the whole thing up
as a simulation. We have another trick here: we compute the neighbor
list with an extra "halo", treating atoms as neighbors if they are
within distance 1 + halo of each other, rather than just the
interaction range 1. This way, we only have to recompute the neighbor
list when some atom has moved more than halo/2 from where it was when
the neighbor list was computed. Until then the list remains sound: if
every atom has moved at most halo/2, the distance between any two atoms
has shrunk by at most halo, so every pair now within range 1 was within
1 + halo when the list was built.
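The rebuild test and the halo/2 argument can be checked with a small plain-Python sketch. It assumes free space (the Dex `simulate` uses the supplied, possibly periodic, `displacement`), and `needs_rebuild` is a hypothetical name for the `rebuild = fast_any ...` expression:

```python
import math

def needs_rebuild(saved, current, halo):
    # True as soon as one atom has moved more than halo / 2 since `saved` was recorded.
    # any() stops at the first True, much like the Dex `fast_any` helper.
    return any(2.0 * math.dist(a, b) > halo for a, b in zip(saved, current))

halo = 0.5
saved = [[0.0, 0.0], [1.0 + halo + 0.01, 0.0]]  # just outside the 1 + halo list radius
moved = [[0.24, 0.0], [1.27, 0.0]]              # each atom moved 0.24 < halo / 2 toward the other
print(needs_rebuild(saved, moved, halo))        # False: no rebuild yet
print(math.dist(*moved))                        # about 1.03, still outside interaction range 1
```

Even in this worst case, with both atoms moving straight toward each other by almost halo/2, a pair that was left off the list stays outside the interaction range, so skipping the rebuild loses nothing.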
-- TODO(Issue 1133) Can't use scan with a body that has an effect?
def simulate {atom_ix}
    (displacement : Displacement TwoDimensions)
    halo_size
    L
    time [Ix time]
    (state : FireDescentState atom_ix TwoDimensions)
    : {IO} (FireDescentState atom_ix TwoDimensions & time => Float) =
  with_state (get_position state) \saved_atoms_ref.
    nbrs = just_neighbor_list (1.0 + halo_size) L $ get_position state
    (AsList k _) = nbrs
    print $ show k <> " initial neighbor list size"
    with_state nbrs \saved_neighbors_ref.
      swap $ run_state state \s_ref. for i.
        s = get s_ref
        new_atoms = get_position s
        rebuild = fast_any \i.
          2 * norm (displacement (get saved_atoms_ref!i) new_atoms.i) > halo_size
        if rebuild then
          saved_atoms_ref := new_atoms
          nbrs = just_neighbor_list (1.0 + halo_size) L new_atoms
          saved_neighbors_ref := nbrs
          (AsList k _) = nbrs
          print $ show k <> " new neighbor list size"
        nbrs = get saved_neighbors_ref
        s' = fire_descent_step (periodic_shift L) (energy_nl L nbrs) s
        s_ref := s'
        energy_nl L nbrs $ get_position s'
Let's check that it works on our test system from before: